import torch
import torch.nn.functional as F
import utils.graph_lib as graphs
from functools import partial

def discrete_loss(model, data, cond, graph : graphs.Graph):
    """
    Cross-entropy denoising loss for a discrete (graph-based) noising process.

    One time ``t`` is drawn per batch element, uniformly on ``[graph.delta, 1)``.
    ``data`` is corrupted with ``graph.sample_transition`` at the level returned
    by ``graph.sigma_int(t)``. The model's output on the corrupted input is then
    scored with cross-entropy against the original (clean) tokens.

    Args:
        model: callable invoked as ``model(perturbed_data, sigma_int, cond)``.
            The last axis of its output is treated as per-class logits.
        data: integer token tensor with the batch on dim 0 (documented as
            ``[B, L]``). The vocabulary size D is given by ``graph``.
        cond: conditioning input, forwarded to ``model`` unchanged.
        graph: object providing ``delta``, ``sigma_int`` and
            ``sample_transition``.

    Returns:
        Scalar tensor: the per-token cross-entropy averaged over every
        element of ``data``.
    """
    batch_size = data.shape[0]
    t_min = graph.delta

    # Affine map of torch.rand's [0, 1) onto [t_min, 1), keeping t off zero.
    t = (1 - t_min) * torch.rand(batch_size, device=data.device) + t_min
    sigma_int = graph.sigma_int(t)

    # The transition is sampled on a (B, -1) view with one noise level per
    # row. The result is then restored to the caller's original layout.
    flat = data.reshape(batch_size, -1)
    noisy = graph.sample_transition(flat, sigma_int[:, None])
    noisy = noisy.reshape(data.shape)

    logits = model(noisy, sigma_int, cond)

    # NOTE: the objective is plain cross-entropy against the clean tokens,
    # not graph.score_entropy.
    num_classes = logits.shape[-1]
    per_token = F.cross_entropy(
        logits.reshape(-1, num_classes),
        data.flatten().long(),
        reduction='none',
    )
    return per_token.mean()

def get_loss(graph):
    """
    Return ``discrete_loss`` with ``graph`` pre-bound as a keyword argument.

    The result is a ``functools.partial`` and is called as
    ``loss_fn(model, data, cond)``.
    """
    loss_fn = partial(discrete_loss, graph=graph)
    return loss_fn